
[aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 #17031

Merged
6 commits merged into microsoft:main on Jan 22, 2024

Conversation

@snadampal
Contributor

commented Aug 7, 2023

Description

This PR adds SbgemmKernel for aarch64. This includes the Sbgemm kernel, which implements matrix multiplication with bfloat16 SIMD instructions (bfmmla), and MatMul operator changes to invoke the Sbgemm kernel. To enable the Sbgemm kernel, set the following session option:
"kOrtSessionOptionsGemmFastMathMode"

The PR also adds new test cases for mlas and ort.

Motivation and Context

This is to improve MatMul performance on the aarch64 platform.
I have run the benchmarking script below (BERT, RoBERTa, and GPT-2 model inference) on an AWS Graviton3-based c7g.4xl instance and observed a 1.2x to 1.76x performance improvement compared to the sgemm (fp32) kernel.

```
cd onnxruntime/python/tools/transformers
python3 benchmark.py
```

The unit test precision results match the sgemm kernel results. The build command used:

```
./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync
```

@snadampal
Contributor Author

I would appreciate it if someone could review this PR.

@snadampal
Contributor Author

Hi @snnn, would you be able to review and provide feedback on this PR? I appreciate your time.

@snadampal
Contributor Author

Hi, I have rebased the PR to resolve the merge conflicts. I'm happy to address any feedback you may have. Thank you!

@milpuz01
Contributor

I have checked out the changes and run performance and accuracy tests, with and without the flag, using onnxruntime_perf_test (I modified the binary to dump outputs for comparison) on AWS Graviton3 instances, and the results look fine.

@snadampal force-pushed the sbgemm_aarch64 branch 4 times, most recently from eb257ff to 83a6f6e on October 4, 2023 at 19:29
@snadampal
Contributor Author

Hi @chenfucn, @yufenglee, I have updated the PR (1) to move to the newer gemm interface and (2) to add session-option-based fastmath mode control. Please review and let me know your feedback.

@snadampal
Contributor Author

Hi @chenfucn, @yufenglee, I would appreciate it if someone could trigger the CI for this PR. I have addressed all the feedback except the Windows testing, for which I'm waiting on the Windows CI results. Thank you!

@chenfucn
Contributor

@chenfucn left a comment

As we discussed, please add mlas unit tests that call the kernel directly with different shapes and other parameters.

Review threads on onnxruntime/core/providers/cpu/math/matmul.cc and onnxruntime/test/providers/base_tester.cc (outdated, resolved)
@chenfucn
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline

@chenfucn
Contributor

/azp run ONNX Runtime Web CI Pipeline, Windows ARM64 QNN CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 7 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@snadampal
Contributor Author

Thanks for the review, I will update the PR to address this and also add unit tests.

@snadampal
Contributor Author

snadampal commented Oct 25, 2023

I have updated the PR to address all the feedback so far, and also the learnings from my other qgemm PR:
(1) enabled the feature only on non-Apple platforms
(2) added mlas unit tests
(3) tested the Linux full build (both release and release-with-debug-info)
(4) tested the minimal build
(5) tested the Android build with cross compilation on x86
(6) ran lintrunner and git-clang-format

Next, I will add ort optimizer and provider tests to exercise the fastmath session option.
Please review and let me know if you have any feedback on this version.

@snnn
Member

snnn commented Jan 17, 2024

@snadampal did you push your change to GitHub?

@snadampal
Contributor Author

snadampal commented Jan 17, 2024

Not yet. I'm planning to push the code formatting changes along with the session option name change.

@snadampal
Contributor Author

Hi @skottmckay, I would appreciate your response on this. If you could comment on the naming part, I will update the PR. The following is what I think should cover your suggestion. You had suggested replacing "session" with "mlas". I think "session" makes sense because the config is still session specific, so I left it but added "mlas" to it:

static const char* const kOrtSessionOptionsMlasGemmFastMathMode = "session.enable_mlas_gemm_fastmath_mode";

@skottmckay
Contributor

> Hi @skottmckay, if you could comment on the naming part, I will update the PR. The following is what I think should cover your suggestion. You had suggested replacing "session" with "mlas". I think "session" makes sense because the config is still session specific, so I left it but added "mlas" to it:
>
> static const char* const kOrtSessionOptionsMlasGemmFastMathMode = "session.enable_mlas_gemm_fastmath_mode";

I think I would consider the first name as something that points me to where I would find the setting being used. e.g. 'optimization' means look in the optimizer project. I would say it's inferred you're configuring something in the session as you're using SessionOptions (vs. say RunOptions). Based on that, I would vote for 'mlas.' as the prefix.

The name also seems a little too generic as it sounds like it would apply to MLAS as a whole. Unless we think there will be some other fastpath that applies to MLAS GEMM in general, a more specific name would be clearer. e.g. mlas.enable_gemm_fastmath_arm64_bfloat16

Or alternatively the platform/datatype could be in the value and you could parse that.

e.g. mlas.enable_gemm_fastmath_mode could have a value of arm64.bfloat16, and additional platform.data_type values could be parsed for. That is obviously more complicated, so we should avoid it unless we think it would be used. It would possibly be required if we had this GEMM fastpath on multiple platforms or for multiple data types, though, as I assume I'd want to be able to enable/disable each specific combination, and having a new config key for every single combination doesn't scale well.
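For illustration only, a minimal sketch (not code from this PR) of how such a combined value might be parsed; the helper name and the comma-separated "platform.data_type" value format are assumptions based on the suggestion above:

```cpp
#include <set>
#include <sstream>
#include <string>

// Hypothetical helper: parse a value such as "arm64.bfloat16,x64.float16"
// from a single "mlas.enable_gemm_fastmath_mode" entry and report whether a
// given platform/data-type combination was requested.
bool FastMathEnabledFor(const std::string& config_value,
                        const std::string& platform,
                        const std::string& data_type) {
  std::set<std::string> requested;
  std::stringstream ss(config_value);
  std::string token;
  while (std::getline(ss, token, ',')) {
    requested.insert(token);  // each token is "platform.data_type"
  }
  return requested.count(platform + "." + data_type) > 0;
}

// Example: FastMathEnabledFor("arm64.bfloat16", "arm64", "bfloat16") -> true
```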

@snadampal
Contributor Author

Thank you, I see your point. bf16 and fp16 are the potential fastmath options, but on aarch64 I so far see interest in bf16 fastmath only. I agree that there may not be multiple of these for different platforms, so I will go ahead with a simple config key:

static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
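For reference, a minimal sketch (not from the PR itself) of how an application could opt in to this option through the public C++ API; the model path is a placeholder, and this assumes a Linux aarch64 build where the session constructor takes a narrow-character path:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sbgemm-fastmath");
  Ort::SessionOptions so;

  // Opt in to the bfloat16 fastmath GEMM path; it is off by default.
  so.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");

  // "model.onnx" is a placeholder; MatMul nodes in this session can then
  // route through the Sbgemm kernel when the hardware supports bf16.
  Ort::Session session(env, "model.onnx", so);
  return 0;
}
```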

Added the SbgemmKernel assembly implementation with bfmmla instructions and
sbgemm utility functions to prepack matrix B along with conversion to bfloat16.
The sbgemm kernel is invoked when fastmath mode is enabled and the hardware
supports the bf16 instruction set. It is disabled by default; set the following
session option to 1 to enable it:
"kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16"
@snadampal
Contributor Author

Updated the PR for the session option name and the other points discussed so far, including clang formatting. Tested:

  1. release, debug, and minimal builds on aarch64 Neoverse V1 and N1 platforms
  2. Android build and Linux cross compilation for the aarch64 config on an x86 platform

@chenfucn
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline, Windows ARM64 QNN CI Pipeline

@chenfucn
Contributor

/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows x64 QNN CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@snnn merged commit 77da2ef into microsoft:main on Jan 22, 2024
55 of 56 checks passed
@snadampal
Contributor Author

Thanks to @chenfucn, @snnn, @skottmckay, and @yufenglee for the great feedback and for merging the PR!

YUNQIUGUO pushed a commit that referenced this pull request on Jan 23, 2024: [aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 (#17031)

@snnn
Member

snnn commented Jan 24, 2024

@snadampal, thanks for making ONNX Runtime better. You are welcome to bring more changes to us. You have my email; do not hesitate to contact me anytime you need help with reviewing PRs.
